{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3b44703",
   "metadata": {},
   "source": [
    "# Train a PyTorch model with SparkXShards"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "416b0a4b",
   "metadata": {},
   "source": [
    "Copyright 2016 The BigDL Authors."
   ]
  },
  {
   "cell_type": "raw",
   "id": "a474a628",
   "metadata": {},
   "source": [
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf304c91",
   "metadata": {},
   "source": [
    "`SparkXShards` in Orca lets users process large-scale datasets with existing Python code in a distributed, data-parallel fashion, as shown below. This notebook is an example of how to train a PyTorch model on Orca using data stored in `SparkXShards`.\n",
    "\n",
    "It is adapted from [PyTorch Tutorial: How to Develop Deep Learning Models with Python](https://machinelearningmastery.com/pytorch-tutorial-develop-deep-learning-models/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31feb952",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import necessary libraries\n",
    "from torch.nn import Linear\n",
    "from torch.nn import ReLU\n",
    "from torch.nn import Sigmoid\n",
    "from torch.nn import Module\n",
    "from torch.optim import SGD\n",
    "from torch.nn import BCELoss\n",
    "from torch.nn.init import kaiming_uniform_\n",
    "from torch.nn.init import xavier_uniform_\n",
    "\n",
    "import bigdl.orca.data.pandas\n",
    "from bigdl.orca import init_orca_context, stop_orca_context\n",
    "from bigdl.orca.learn.pytorch import Estimator\n",
    "from bigdl.orca.learn.metrics import Accuracy\n",
    "from bigdl.orca.data.transformer import StringIndexer\n",
    "import os\n",
    "\n",
    "# work around the duplicate OpenMP runtime error that can occur on some platforms\n",
    "os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fde7f66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start an OrcaContext\n",
    "sc = init_orca_context(cores=4, memory=\"4g\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a46ad6a7",
   "metadata": {},
   "source": [
    "## Load data in parallel and get general information"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "904060e4",
   "metadata": {},
   "source": [
    "Load the data into `data_shards`, a `SparkXShards` that can be operated on in parallel. Each element of `data_shards` is a pandas DataFrame read from a file on the cluster. Users can distribute the local code `pd.read_csv(dataFile)` by replacing it with `bigdl.orca.data.pandas.read_csv(datapath)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69fdeb3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_shards = bigdl.orca.data.pandas.read_csv('../ionosphere.csv', header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6d12d912",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>25</th>\n",
       "      <th>26</th>\n",
       "      <th>27</th>\n",
       "      <th>28</th>\n",
       "      <th>29</th>\n",
       "      <th>30</th>\n",
       "      <th>31</th>\n",
       "      <th>32</th>\n",
       "      <th>33</th>\n",
       "      <th>34</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0.99539</td>\n",
       "      <td>-0.05889</td>\n",
       "      <td>0.85243</td>\n",
       "      <td>0.02306</td>\n",
       "      <td>0.83398</td>\n",
       "      <td>-0.37708</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>0.03760</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.51171</td>\n",
       "      <td>0.41078</td>\n",
       "      <td>-0.46168</td>\n",
       "      <td>0.21266</td>\n",
       "      <td>-0.34090</td>\n",
       "      <td>0.42267</td>\n",
       "      <td>-0.54487</td>\n",
       "      <td>0.18641</td>\n",
       "      <td>-0.45300</td>\n",
       "      <td>g</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>-0.18829</td>\n",
       "      <td>0.93035</td>\n",
       "      <td>-0.36156</td>\n",
       "      <td>-0.10868</td>\n",
       "      <td>-0.93597</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>-0.04549</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.26569</td>\n",
       "      <td>-0.20468</td>\n",
       "      <td>-0.18401</td>\n",
       "      <td>-0.19040</td>\n",
       "      <td>-0.11593</td>\n",
       "      <td>-0.16626</td>\n",
       "      <td>-0.06288</td>\n",
       "      <td>-0.13738</td>\n",
       "      <td>-0.02447</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>-0.03365</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>0.00485</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>-0.12062</td>\n",
       "      <td>0.88965</td>\n",
       "      <td>0.01198</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.40220</td>\n",
       "      <td>0.58984</td>\n",
       "      <td>-0.22145</td>\n",
       "      <td>0.43100</td>\n",
       "      <td>-0.17365</td>\n",
       "      <td>0.60436</td>\n",
       "      <td>-0.24180</td>\n",
       "      <td>0.56045</td>\n",
       "      <td>-0.38238</td>\n",
       "      <td>g</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>-0.45161</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>0.71216</td>\n",
       "      <td>-1.00000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.90695</td>\n",
       "      <td>0.51613</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>-0.20099</td>\n",
       "      <td>0.25682</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>-0.32382</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>-0.02401</td>\n",
       "      <td>0.94140</td>\n",
       "      <td>0.06531</td>\n",
       "      <td>0.92106</td>\n",
       "      <td>-0.23255</td>\n",
       "      <td>0.77152</td>\n",
       "      <td>-0.16399</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.65158</td>\n",
       "      <td>0.13290</td>\n",
       "      <td>-0.53206</td>\n",
       "      <td>0.02431</td>\n",
       "      <td>-0.62197</td>\n",
       "      <td>-0.05707</td>\n",
       "      <td>-0.59573</td>\n",
       "      <td>-0.04608</td>\n",
       "      <td>-0.65697</td>\n",
       "      <td>g</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 35 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   0  1        2        3        4        5        6        7        8  \\\n",
       "0  1  0  0.99539 -0.05889  0.85243  0.02306  0.83398 -0.37708  1.00000   \n",
       "1  1  0  1.00000 -0.18829  0.93035 -0.36156 -0.10868 -0.93597  1.00000   \n",
       "2  1  0  1.00000 -0.03365  1.00000  0.00485  1.00000 -0.12062  0.88965   \n",
       "3  1  0  1.00000 -0.45161  1.00000  1.00000  0.71216 -1.00000  0.00000   \n",
       "4  1  0  1.00000 -0.02401  0.94140  0.06531  0.92106 -0.23255  0.77152   \n",
       "\n",
       "         9  ...       25       26       27       28       29       30  \\\n",
       "0  0.03760  ... -0.51171  0.41078 -0.46168  0.21266 -0.34090  0.42267   \n",
       "1 -0.04549  ... -0.26569 -0.20468 -0.18401 -0.19040 -0.11593 -0.16626   \n",
       "2  0.01198  ... -0.40220  0.58984 -0.22145  0.43100 -0.17365  0.60436   \n",
       "3  0.00000  ...  0.90695  0.51613  1.00000  1.00000 -0.20099  0.25682   \n",
       "4 -0.16399  ... -0.65158  0.13290 -0.53206  0.02431 -0.62197 -0.05707   \n",
       "\n",
       "        31       32       33  34  \n",
       "0 -0.54487  0.18641 -0.45300   g  \n",
       "1 -0.06288 -0.13738 -0.02447   b  \n",
       "2 -0.24180  0.56045 -0.38238   g  \n",
       "3  1.00000 -0.32382  1.00000   b  \n",
       "4 -0.59573 -0.04608 -0.65697   g  \n",
       "\n",
       "[5 rows x 35 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the first five rows of data_shards\n",
    "data_shards.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "62d45e04",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check the number of partitions in data_shards\n",
    "data_shards.num_partitions()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3b8526d9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "351"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# count total number of rows in the data_shards\n",
    "len(data_shards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "de53c793",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the column names of the DataFrames in data_shards\n",
    "columns = data_shards.get_schema()['columns']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce183207",
   "metadata": {},
   "source": [
    "## Encode labels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef28781d",
   "metadata": {},
   "source": [
    "The labels are strings. Users can encode them as integers using `StringIndexer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3c1d783e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "createDataFrame from shards attempted Arrow optimization failed as: 'NoneType' object has no attribute 'json',Will try without Arrow optimization\n",
      "2022-11-23 23:49:03 WARN  Utils:66 - Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "create shards from Spark DataFrame attempted Arrow optimization failed as: name 'df' is not defined. Will try without Arrow optimization\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "[Stage 12:>                                                         (0 + 1) / 1]\r",
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "label_encoder = StringIndexer(inputCol=columns[-1])\n",
    "data_shards = label_encoder.fit_transform(data_shards)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6caf0c40",
   "metadata": {},
   "source": [
    "The encoded labels start from 1, so they need to be shifted to be zero-based."
   ]
  },
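  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3d90bdc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_label_to_zero_base(df):\n",
    "    # shift the label column ('34') from 1-based to 0-based\n",
    "    df['34'] = df['34'] - 1\n",
    "    # cast all columns to float32 (PyTorch's default dtype)\n",
    "    df = df.astype(\"float32\")\n",
    "    return df\n",
    "\n",
    "# apply the function to every pandas DataFrame in data_shards\n",
    "data_shards = data_shards.transform_shard(update_label_to_zero_base)"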
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b94806a",
   "metadata": {},
   "source": [
    "## Assemble features and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "52ca4dcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_shards = data_shards.assembleFeatureLabelCols(featureCols=list(columns[:-1]),\n",
    "                                                   labelCols=[columns[-1]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5cf45f4",
   "metadata": {},
   "source": [
    "## Define PyTorch model and train it"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6318a0a",
   "metadata": {},
   "source": [
    "Users can build a PyTorch model as usual and use the Orca Estimator to train it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d6e4a182",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# define an MLP model\n",
    "class MLP(Module):\n",
    "    # define model elements\n",
    "    def __init__(self, n_inputs):\n",
    "        super(MLP, self).__init__()\n",
    "        # input to first hidden layer\n",
    "        self.hidden1 = Linear(n_inputs, 10)\n",
    "        kaiming_uniform_(self.hidden1.weight, nonlinearity='relu')\n",
    "        self.act1 = ReLU()\n",
    "        # second hidden layer\n",
    "        self.hidden2 = Linear(10, 8)\n",
    "        kaiming_uniform_(self.hidden2.weight, nonlinearity='relu')\n",
    "        self.act2 = ReLU()\n",
    "        # third hidden layer and output\n",
    "        self.hidden3 = Linear(8, 1)\n",
    "        xavier_uniform_(self.hidden3.weight)\n",
    "        self.act3 = Sigmoid()\n",
    "\n",
    "    # forward propagate input\n",
    "    def forward(self, X):\n",
    "        # input to first hidden layer\n",
    "        X = self.hidden1(X)\n",
    "        X = self.act1(X)\n",
    "        # second hidden layer\n",
    "        X = self.hidden2(X)\n",
    "        X = self.act2(X)\n",
    "        # third hidden layer and output\n",
    "        X = self.hidden3(X)\n",
    "        X = self.act3(X)\n",
    "        return X\n",
    "\n",
    "\n",
    "def model_creator(config):\n",
    "    model = MLP(config[\"n_inputs\"])\n",
    "    model.train()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "eb845acd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define criterion and optimizer\n",
    "def optimizer_creator(model, config):\n",
    "    optimizer = SGD(model.parameters(), lr=config[\"lr\"], momentum=config[\"momentum\"])\n",
    "    return optimizer\n",
    "\n",
    "criterion = BCELoss()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c48782c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "creating: createTorchLoss\n",
      "creating: createTorchOptim\n",
      "creating: createZooKerasAccuracy\n",
      "creating: createEstimator\n"
     ]
    }
   ],
   "source": [
    "# build Orca Estimator\n",
    "orca_estimator = Estimator.from_torch(model=model_creator,\n",
    "                                      optimizer=optimizer_creator,\n",
    "                                      loss=criterion,\n",
    "                                      metrics=[Accuracy()],\n",
    "                                      backend=\"spark\",\n",
    "                                      config={\"n_inputs\": 34,\n",
    "                                              \"lr\": 0.01,\n",
    "                                              \"momentum\": 0.9})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d73b9bc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the model\n",
    "orca_estimator.fit(data=data_shards, epochs=100, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "348040e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stopping orca context\n"
     ]
    }
   ],
   "source": [
    "# stop OrcaContext\n",
    "stop_orca_context()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37tf2_x",
   "language": "python",
   "name": "py37tf2_x"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
